Done By: Mudit Sharma - 21BCE2223
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
df = pd.read_csv('/content/Employee-Attrition.csv')
df
| Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | Yes | Travel_Rarely | 1102 | Sales | 1 | 2 | Life Sciences | 1 | 1 | ... | 1 | 80 | 0 | 8 | 0 | 1 | 6 | 4 | 0 | 5 |
| 1 | 49 | No | Travel_Frequently | 279 | Research & Development | 8 | 1 | Life Sciences | 1 | 2 | ... | 4 | 80 | 1 | 10 | 3 | 3 | 10 | 7 | 1 | 7 |
| 2 | 37 | Yes | Travel_Rarely | 1373 | Research & Development | 2 | 2 | Other | 1 | 4 | ... | 2 | 80 | 0 | 7 | 3 | 3 | 0 | 0 | 0 | 0 |
| 3 | 33 | No | Travel_Frequently | 1392 | Research & Development | 3 | 4 | Life Sciences | 1 | 5 | ... | 3 | 80 | 0 | 8 | 3 | 3 | 8 | 7 | 3 | 0 |
| 4 | 27 | No | Travel_Rarely | 591 | Research & Development | 2 | 1 | Medical | 1 | 7 | ... | 4 | 80 | 1 | 6 | 3 | 3 | 2 | 2 | 2 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1465 | 36 | No | Travel_Frequently | 884 | Research & Development | 23 | 2 | Medical | 1 | 2061 | ... | 3 | 80 | 1 | 17 | 3 | 3 | 5 | 2 | 0 | 3 |
| 1466 | 39 | No | Travel_Rarely | 613 | Research & Development | 6 | 1 | Medical | 1 | 2062 | ... | 1 | 80 | 1 | 9 | 5 | 3 | 7 | 7 | 1 | 7 |
| 1467 | 27 | No | Travel_Rarely | 155 | Research & Development | 4 | 3 | Life Sciences | 1 | 2064 | ... | 2 | 80 | 1 | 6 | 0 | 3 | 6 | 2 | 0 | 3 |
| 1468 | 49 | No | Travel_Frequently | 1023 | Sales | 2 | 3 | Medical | 1 | 2065 | ... | 4 | 80 | 0 | 17 | 3 | 2 | 9 | 6 | 0 | 8 |
| 1469 | 34 | No | Travel_Rarely | 628 | Research & Development | 8 | 3 | Medical | 1 | 2068 | ... | 1 | 80 | 0 | 6 | 3 | 4 | 4 | 3 | 1 | 2 |
1470 rows × 35 columns
df.head()
| Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | Yes | Travel_Rarely | 1102 | Sales | 1 | 2 | Life Sciences | 1 | 1 | ... | 1 | 80 | 0 | 8 | 0 | 1 | 6 | 4 | 0 | 5 |
| 1 | 49 | No | Travel_Frequently | 279 | Research & Development | 8 | 1 | Life Sciences | 1 | 2 | ... | 4 | 80 | 1 | 10 | 3 | 3 | 10 | 7 | 1 | 7 |
| 2 | 37 | Yes | Travel_Rarely | 1373 | Research & Development | 2 | 2 | Other | 1 | 4 | ... | 2 | 80 | 0 | 7 | 3 | 3 | 0 | 0 | 0 | 0 |
| 3 | 33 | No | Travel_Frequently | 1392 | Research & Development | 3 | 4 | Life Sciences | 1 | 5 | ... | 3 | 80 | 0 | 8 | 3 | 3 | 8 | 7 | 3 | 0 |
| 4 | 27 | No | Travel_Rarely | 591 | Research & Development | 2 | 1 | Medical | 1 | 7 | ... | 4 | 80 | 1 | 6 | 3 | 3 | 2 | 2 | 2 | 2 |
5 rows × 35 columns
df.tail()
| Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1465 | 36 | No | Travel_Frequently | 884 | Research & Development | 23 | 2 | Medical | 1 | 2061 | ... | 3 | 80 | 1 | 17 | 3 | 3 | 5 | 2 | 0 | 3 |
| 1466 | 39 | No | Travel_Rarely | 613 | Research & Development | 6 | 1 | Medical | 1 | 2062 | ... | 1 | 80 | 1 | 9 | 5 | 3 | 7 | 7 | 1 | 7 |
| 1467 | 27 | No | Travel_Rarely | 155 | Research & Development | 4 | 3 | Life Sciences | 1 | 2064 | ... | 2 | 80 | 1 | 6 | 0 | 3 | 6 | 2 | 0 | 3 |
| 1468 | 49 | No | Travel_Frequently | 1023 | Sales | 2 | 3 | Medical | 1 | 2065 | ... | 4 | 80 | 0 | 17 | 3 | 2 | 9 | 6 | 0 | 8 |
| 1469 | 34 | No | Travel_Rarely | 628 | Research & Development | 8 | 3 | Medical | 1 | 2068 | ... | 1 | 80 | 0 | 6 | 3 | 4 | 4 | 3 | 1 | 2 |
5 rows × 35 columns
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 1470 entries, 0 to 1469 Data columns (total 35 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Age 1470 non-null int64 1 Attrition 1470 non-null object 2 BusinessTravel 1470 non-null object 3 DailyRate 1470 non-null int64 4 Department 1470 non-null object 5 DistanceFromHome 1470 non-null int64 6 Education 1470 non-null int64 7 EducationField 1470 non-null object 8 EmployeeCount 1470 non-null int64 9 EmployeeNumber 1470 non-null int64 10 EnvironmentSatisfaction 1470 non-null int64 11 Gender 1470 non-null object 12 HourlyRate 1470 non-null int64 13 JobInvolvement 1470 non-null int64 14 JobLevel 1470 non-null int64 15 JobRole 1470 non-null object 16 JobSatisfaction 1470 non-null int64 17 MaritalStatus 1470 non-null object 18 MonthlyIncome 1470 non-null int64 19 MonthlyRate 1470 non-null int64 20 NumCompaniesWorked 1470 non-null int64 21 Over18 1470 non-null object 22 OverTime 1470 non-null object 23 PercentSalaryHike 1470 non-null int64 24 PerformanceRating 1470 non-null int64 25 RelationshipSatisfaction 1470 non-null int64 26 StandardHours 1470 non-null int64 27 StockOptionLevel 1470 non-null int64 28 TotalWorkingYears 1470 non-null int64 29 TrainingTimesLastYear 1470 non-null int64 30 WorkLifeBalance 1470 non-null int64 31 YearsAtCompany 1470 non-null int64 32 YearsInCurrentRole 1470 non-null int64 33 YearsSinceLastPromotion 1470 non-null int64 34 YearsWithCurrManager 1470 non-null int64 dtypes: int64(26), object(9) memory usage: 402.1+ KB
df.describe()
| Age | DailyRate | DistanceFromHome | Education | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | HourlyRate | JobInvolvement | JobLevel | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.0 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | ... | 1470.000000 | 1470.0 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 | 1470.000000 |
| mean | 36.923810 | 802.485714 | 9.192517 | 2.912925 | 1.0 | 1024.865306 | 2.721769 | 65.891156 | 2.729932 | 2.063946 | ... | 2.712245 | 80.0 | 0.793878 | 11.279592 | 2.799320 | 2.761224 | 7.008163 | 4.229252 | 2.187755 | 4.123129 |
| std | 9.135373 | 403.509100 | 8.106864 | 1.024165 | 0.0 | 602.024335 | 1.093082 | 20.329428 | 0.711561 | 1.106940 | ... | 1.081209 | 0.0 | 0.852077 | 7.780782 | 1.289271 | 0.706476 | 6.126525 | 3.623137 | 3.222430 | 3.568136 |
| min | 18.000000 | 102.000000 | 1.000000 | 1.000000 | 1.0 | 1.000000 | 1.000000 | 30.000000 | 1.000000 | 1.000000 | ... | 1.000000 | 80.0 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 30.000000 | 465.000000 | 2.000000 | 2.000000 | 1.0 | 491.250000 | 2.000000 | 48.000000 | 2.000000 | 1.000000 | ... | 2.000000 | 80.0 | 0.000000 | 6.000000 | 2.000000 | 2.000000 | 3.000000 | 2.000000 | 0.000000 | 2.000000 |
| 50% | 36.000000 | 802.000000 | 7.000000 | 3.000000 | 1.0 | 1020.500000 | 3.000000 | 66.000000 | 3.000000 | 2.000000 | ... | 3.000000 | 80.0 | 1.000000 | 10.000000 | 3.000000 | 3.000000 | 5.000000 | 3.000000 | 1.000000 | 3.000000 |
| 75% | 43.000000 | 1157.000000 | 14.000000 | 4.000000 | 1.0 | 1555.750000 | 4.000000 | 83.750000 | 3.000000 | 3.000000 | ... | 4.000000 | 80.0 | 1.000000 | 15.000000 | 3.000000 | 3.000000 | 9.000000 | 7.000000 | 3.000000 | 7.000000 |
| max | 60.000000 | 1499.000000 | 29.000000 | 5.000000 | 1.0 | 2068.000000 | 4.000000 | 100.000000 | 4.000000 | 5.000000 | ... | 4.000000 | 80.0 | 3.000000 | 40.000000 | 6.000000 | 4.000000 | 40.000000 | 18.000000 | 15.000000 | 17.000000 |
8 rows × 26 columns
# Pairwise correlation of the numeric columns only. The frame also holds
# object-dtype columns (Attrition, Department, ...) that cannot be correlated;
# relying on the implicit numeric_only default is deprecated, so state it.
corr = df.corr(numeric_only=True)
corr
<ipython-input-168-4381f08f6434>:1: FutureWarning: The default value of numeric_only in DataFrame.corr is deprecated. In a future version, it will default to False. Select only valid columns or specify the value of numeric_only to silence this warning. corr = df.corr()
| Age | DailyRate | DistanceFromHome | Education | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | HourlyRate | JobInvolvement | JobLevel | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Age | 1.000000 | 0.010661 | -0.001686 | 0.208034 | NaN | -0.010145 | 0.010146 | 0.024287 | 0.029820 | 0.509604 | ... | 0.053535 | NaN | 0.037510 | 0.680381 | -0.019621 | -0.021490 | 0.311309 | 0.212901 | 0.216513 | 0.202089 |
| DailyRate | 0.010661 | 1.000000 | -0.004985 | -0.016806 | NaN | -0.050990 | 0.018355 | 0.023381 | 0.046135 | 0.002966 | ... | 0.007846 | NaN | 0.042143 | 0.014515 | 0.002453 | -0.037848 | -0.034055 | 0.009932 | -0.033229 | -0.026363 |
| DistanceFromHome | -0.001686 | -0.004985 | 1.000000 | 0.021042 | NaN | 0.032916 | -0.016075 | 0.031131 | 0.008783 | 0.005303 | ... | 0.006557 | NaN | 0.044872 | 0.004628 | -0.036942 | -0.026556 | 0.009508 | 0.018845 | 0.010029 | 0.014406 |
| Education | 0.208034 | -0.016806 | 0.021042 | 1.000000 | NaN | 0.042070 | -0.027128 | 0.016775 | 0.042438 | 0.101589 | ... | -0.009118 | NaN | 0.018422 | 0.148280 | -0.025100 | 0.009819 | 0.069114 | 0.060236 | 0.054254 | 0.069065 |
| EmployeeCount | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| EmployeeNumber | -0.010145 | -0.050990 | 0.032916 | 0.042070 | NaN | 1.000000 | 0.017621 | 0.035179 | -0.006888 | -0.018519 | ... | -0.069861 | NaN | 0.062227 | -0.014365 | 0.023603 | 0.010309 | -0.011240 | -0.008416 | -0.009019 | -0.009197 |
| EnvironmentSatisfaction | 0.010146 | 0.018355 | -0.016075 | -0.027128 | NaN | 0.017621 | 1.000000 | -0.049857 | -0.008278 | 0.001212 | ... | 0.007665 | NaN | 0.003432 | -0.002693 | -0.019359 | 0.027627 | 0.001458 | 0.018007 | 0.016194 | -0.004999 |
| HourlyRate | 0.024287 | 0.023381 | 0.031131 | 0.016775 | NaN | 0.035179 | -0.049857 | 1.000000 | 0.042861 | -0.027853 | ... | 0.001330 | NaN | 0.050263 | -0.002334 | -0.008548 | -0.004607 | -0.019582 | -0.024106 | -0.026716 | -0.020123 |
| JobInvolvement | 0.029820 | 0.046135 | 0.008783 | 0.042438 | NaN | -0.006888 | -0.008278 | 0.042861 | 1.000000 | -0.012630 | ... | 0.034297 | NaN | 0.021523 | -0.005533 | -0.015338 | -0.014617 | -0.021355 | 0.008717 | -0.024184 | 0.025976 |
| JobLevel | 0.509604 | 0.002966 | 0.005303 | 0.101589 | NaN | -0.018519 | 0.001212 | -0.027853 | -0.012630 | 1.000000 | ... | 0.021642 | NaN | 0.013984 | 0.782208 | -0.018191 | 0.037818 | 0.534739 | 0.389447 | 0.353885 | 0.375281 |
| JobSatisfaction | -0.004892 | 0.030571 | -0.003669 | -0.011296 | NaN | -0.046247 | -0.006784 | -0.071335 | -0.021476 | -0.001944 | ... | -0.012454 | NaN | 0.010690 | -0.020185 | -0.005779 | -0.019459 | -0.003803 | -0.002305 | -0.018214 | -0.027656 |
| MonthlyIncome | 0.497855 | 0.007707 | -0.017014 | 0.094961 | NaN | -0.014829 | -0.006259 | -0.015794 | -0.015271 | 0.950300 | ... | 0.025873 | NaN | 0.005408 | 0.772893 | -0.021736 | 0.030683 | 0.514285 | 0.363818 | 0.344978 | 0.344079 |
| MonthlyRate | 0.028051 | -0.032182 | 0.027473 | -0.026084 | NaN | 0.012648 | 0.037600 | -0.015297 | -0.016322 | 0.039563 | ... | -0.004085 | NaN | -0.034323 | 0.026442 | 0.001467 | 0.007963 | -0.023655 | -0.012815 | 0.001567 | -0.036746 |
| NumCompaniesWorked | 0.299635 | 0.038153 | -0.029251 | 0.126317 | NaN | -0.001251 | 0.012594 | 0.022157 | 0.015012 | 0.142501 | ... | 0.052733 | NaN | 0.030075 | 0.237639 | -0.066054 | -0.008366 | -0.118421 | -0.090754 | -0.036814 | -0.110319 |
| PercentSalaryHike | 0.003634 | 0.022704 | 0.040235 | -0.011111 | NaN | -0.012944 | -0.031701 | -0.009062 | -0.017205 | -0.034730 | ... | -0.040490 | NaN | 0.007528 | -0.020608 | -0.005221 | -0.003280 | -0.035991 | -0.001520 | -0.022154 | -0.011985 |
| PerformanceRating | 0.001904 | 0.000473 | 0.027110 | -0.024539 | NaN | -0.020359 | -0.029548 | -0.002172 | -0.029071 | -0.021222 | ... | -0.031351 | NaN | 0.003506 | 0.006744 | -0.015579 | 0.002572 | 0.003435 | 0.034986 | 0.017896 | 0.022827 |
| RelationshipSatisfaction | 0.053535 | 0.007846 | 0.006557 | -0.009118 | NaN | -0.069861 | 0.007665 | 0.001330 | 0.034297 | 0.021642 | ... | 1.000000 | NaN | -0.045952 | 0.024054 | 0.002497 | 0.019604 | 0.019367 | -0.015123 | 0.033493 | -0.000867 |
| StandardHours | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| StockOptionLevel | 0.037510 | 0.042143 | 0.044872 | 0.018422 | NaN | 0.062227 | 0.003432 | 0.050263 | 0.021523 | 0.013984 | ... | -0.045952 | NaN | 1.000000 | 0.010136 | 0.011274 | 0.004129 | 0.015058 | 0.050818 | 0.014352 | 0.024698 |
| TotalWorkingYears | 0.680381 | 0.014515 | 0.004628 | 0.148280 | NaN | -0.014365 | -0.002693 | -0.002334 | -0.005533 | 0.782208 | ... | 0.024054 | NaN | 0.010136 | 1.000000 | -0.035662 | 0.001008 | 0.628133 | 0.460365 | 0.404858 | 0.459188 |
| TrainingTimesLastYear | -0.019621 | 0.002453 | -0.036942 | -0.025100 | NaN | 0.023603 | -0.019359 | -0.008548 | -0.015338 | -0.018191 | ... | 0.002497 | NaN | 0.011274 | -0.035662 | 1.000000 | 0.028072 | 0.003569 | -0.005738 | -0.002067 | -0.004096 |
| WorkLifeBalance | -0.021490 | -0.037848 | -0.026556 | 0.009819 | NaN | 0.010309 | 0.027627 | -0.004607 | -0.014617 | 0.037818 | ... | 0.019604 | NaN | 0.004129 | 0.001008 | 0.028072 | 1.000000 | 0.012089 | 0.049856 | 0.008941 | 0.002759 |
| YearsAtCompany | 0.311309 | -0.034055 | 0.009508 | 0.069114 | NaN | -0.011240 | 0.001458 | -0.019582 | -0.021355 | 0.534739 | ... | 0.019367 | NaN | 0.015058 | 0.628133 | 0.003569 | 0.012089 | 1.000000 | 0.758754 | 0.618409 | 0.769212 |
| YearsInCurrentRole | 0.212901 | 0.009932 | 0.018845 | 0.060236 | NaN | -0.008416 | 0.018007 | -0.024106 | 0.008717 | 0.389447 | ... | -0.015123 | NaN | 0.050818 | 0.460365 | -0.005738 | 0.049856 | 0.758754 | 1.000000 | 0.548056 | 0.714365 |
| YearsSinceLastPromotion | 0.216513 | -0.033229 | 0.010029 | 0.054254 | NaN | -0.009019 | 0.016194 | -0.026716 | -0.024184 | 0.353885 | ... | 0.033493 | NaN | 0.014352 | 0.404858 | -0.002067 | 0.008941 | 0.618409 | 0.548056 | 1.000000 | 0.510224 |
| YearsWithCurrManager | 0.202089 | -0.026363 | 0.014406 | 0.069065 | NaN | -0.009197 | -0.004999 | -0.020123 | 0.025976 | 0.375281 | ... | -0.000867 | NaN | 0.024698 | 0.459188 | -0.004096 | 0.002759 | 0.769212 | 0.714365 | 0.510224 | 1.000000 |
26 rows × 26 columns
df.Attrition.value_counts()
No 1233 Yes 237 Name: Attrition, dtype: int64
df.BusinessTravel.value_counts()
Travel_Rarely 1043 Travel_Frequently 277 Non-Travel 150 Name: BusinessTravel, dtype: int64
df.Department.value_counts()
Research & Development 961 Sales 446 Human Resources 63 Name: Department, dtype: int64
df.Education.value_counts()
3 572 4 398 2 282 1 170 5 48 Name: Education, dtype: int64
df.EducationField.value_counts()
Life Sciences 606 Medical 464 Marketing 159 Technical Degree 132 Other 82 Human Resources 27 Name: EducationField, dtype: int64
df.EmployeeCount.value_counts()
1 1470 Name: EmployeeCount, dtype: int64
df.isnull().any()
Age False Attrition False BusinessTravel False DailyRate False Department False DistanceFromHome False Education False EducationField False EmployeeCount False EmployeeNumber False EnvironmentSatisfaction False Gender False HourlyRate False JobInvolvement False JobLevel False JobRole False JobSatisfaction False MaritalStatus False MonthlyIncome False MonthlyRate False NumCompaniesWorked False Over18 False OverTime False PercentSalaryHike False PerformanceRating False RelationshipSatisfaction False StandardHours False StockOptionLevel False TotalWorkingYears False TrainingTimesLastYear False WorkLifeBalance False YearsAtCompany False YearsInCurrentRole False YearsSinceLastPromotion False YearsWithCurrManager False dtype: bool
df.isnull().sum()
Age 0 Attrition 0 BusinessTravel 0 DailyRate 0 Department 0 DistanceFromHome 0 Education 0 EducationField 0 EmployeeCount 0 EmployeeNumber 0 EnvironmentSatisfaction 0 Gender 0 HourlyRate 0 JobInvolvement 0 JobLevel 0 JobRole 0 JobSatisfaction 0 MaritalStatus 0 MonthlyIncome 0 MonthlyRate 0 NumCompaniesWorked 0 Over18 0 OverTime 0 PercentSalaryHike 0 PerformanceRating 0 RelationshipSatisfaction 0 StandardHours 0 StockOptionLevel 0 TotalWorkingYears 0 TrainingTimesLastYear 0 WorkLifeBalance 0 YearsAtCompany 0 YearsInCurrentRole 0 YearsSinceLastPromotion 0 YearsWithCurrManager 0 dtype: int64
# As there are no null values we can continue with Data Visualization.
# Grid of pairwise relationships between the numeric columns.
# pairplot builds its own figure, so no plt.figure() call is made here:
# opening one after the plot only produced a second, empty figure.
sns.pairplot(df)
plt.show()
<Figure size 2000x1500 with 0 Axes>
# Annotated heatmap of the correlation matrix on a single large axes.
plt.subplots(nrows=1, ncols=1, figsize=(20, 15))
sns.heatmap(data=corr, annot=True)
<Axes: >
# Class balance of the target: number of rows per Attrition value.
plt.figure(figsize=(8, 6))
sns.countplot(x='Attrition', data=df)
plt.gca().set(title='Attrition Count', xlabel='Attrition', ylabel='Count')
plt.show()
# Monthly income spread per department, with one box per Attrition value.
plt.figure(figsize=(10, 6))
sns.boxplot(x='Department', y='MonthlyIncome', hue='Attrition', data=df)
plt.gca().set(
    title='Monthly Income by Department and Attrition',
    xlabel='Department',
    ylabel='Monthly Income',
)
plt.xticks(rotation=45)
plt.show()
# Row counts per job role, split by Attrition.
plt.figure(figsize=(12, 6))
sns.countplot(x='JobRole', hue='Attrition', data=df)
plt.gca().set(title='Attrition by Job Role', xlabel='Job Role', ylabel='Count')
# Role names are long, so the tick labels are turned vertical.
plt.xticks(rotation=90)
plt.show()
# Histogram of TotalWorkingYears with a kernel density estimate overlaid.
plt.figure(figsize=(8, 6))
sns.histplot(x='TotalWorkingYears', kde=True, data=df)
plt.gca().set(
    title='Distribution of Total Working Years',
    xlabel='Total Working Years',
    ylabel='Frequency',
)
plt.show()
# Years in current role per job-satisfaction level, split by Attrition.
plt.figure(figsize=(10, 6))
sns.boxplot(
    x='JobSatisfaction',
    y='YearsInCurrentRole',
    hue='Attrition',
    data=df,
)
plt.gca().set(
    title='Years in Current Role by Job Satisfaction and Attrition',
    xlabel='Job Satisfaction',
    ylabel='Years in Current Role',
)
plt.show()
# Tenure in the current role against tenure at the company, coloured by Attrition.
plt.figure(figsize=(8, 6))
sns.scatterplot(
    x='YearsInCurrentRole',
    y='YearsAtCompany',
    hue='Attrition',
    data=df,
)
plt.gca().set(
    title='Years in Current Role vs. Years at Company',
    xlabel='Years in Current Role',
    ylabel='Years at Company',
)
plt.show()
# One point per row: DailyRate for each Attrition value.
plt.figure(figsize=(8, 6))
sns.swarmplot(x='Attrition', y='DailyRate', data=df)
plt.gca().set(
    title='Daily Rate vs. Attrition',
    xlabel='Attrition',
    ylabel='Daily Rate',
)
plt.show()
# Share of each BusinessTravel category, labelled with one-decimal percentages.
business_travel_counts = df.BusinessTravel.value_counts()
plt.figure(figsize=(8, 8))
plt.pie(
    x=business_travel_counts,
    labels=business_travel_counts.index,
    autopct='%1.1f%%',
    startangle=90,
)
plt.gca().set_title('Business Travel Distribution')
plt.show()
# Job satisfaction aggregated per age (lineplot's default estimator), one line
# per Attrition value.
plt.figure(figsize=(10, 6))
# errorbar=None replaces the deprecated ci=None: no confidence band is drawn.
sns.lineplot(data=df, x='Age', y='JobSatisfaction', hue='Attrition', errorbar=None)
plt.title('Job Satisfaction vs. Age')
plt.xlabel('Age')
plt.ylabel('Job Satisfaction')
plt.show()
<ipython-input-187-587676ac6f43>:2: FutureWarning: The `ci` parameter is deprecated. Use `errorbar=None` for the same effect. sns.lineplot(data=df, x='Age', y='JobSatisfaction', hue='Attrition', ci=None)
# Distribution of TrainingTimesLastYear per marital status; split=True draws
# the two Attrition values as the two halves of each violin.
plt.figure(figsize=(10, 6))
sns.violinplot(
    x='MaritalStatus',
    y='TrainingTimesLastYear',
    hue='Attrition',
    split=True,
    data=df,
)
plt.gca().set(
    title='Training Times Last Year by Marital Status and Attrition',
    xlabel='Marital Status',
    ylabel='Training Times Last Year',
)
plt.xticks(rotation=45)
plt.show()
# Time with the current manager against tenure at the company, coloured by Attrition.
plt.figure(figsize=(8, 6))
sns.scatterplot(
    x='YearsWithCurrManager',
    y='YearsAtCompany',
    hue='Attrition',
    data=df,
)
plt.gca().set(
    title='Years With Current Manager vs. Years at Company',
    xlabel='Years With Current Manager',
    ylabel='Years at Company',
)
plt.show()
# Years since last promotion per gender, split by Attrition.
plt.figure(figsize=(10, 6))
sns.boxplot(
    x='Gender',
    y='YearsSinceLastPromotion',
    hue='Attrition',
    data=df,
)
plt.gca().set(
    title='Years Since Last Promotion by Gender and Attrition',
    xlabel='Gender',
    ylabel='Years Since Last Promotion',
)
plt.show()
# Share of each OverTime value, labelled with one-decimal percentages.
overtime_counts = df.OverTime.value_counts()
plt.figure(figsize=(8, 8))
plt.pie(
    x=overtime_counts,
    labels=overtime_counts.index,
    autopct='%1.1f%%',
    startangle=90,
)
plt.gca().set_title('OverTime Distribution')
plt.show()
# Monthly income aggregated per age (lineplot's default estimator), one line
# per Attrition value.
plt.figure(figsize=(10, 6))
# errorbar=None replaces the deprecated ci=None: no confidence band is drawn.
sns.lineplot(data=df, x='Age', y='MonthlyIncome', hue='Attrition', errorbar=None)
plt.title('Monthly Income vs. Age')
plt.xlabel('Age')
plt.ylabel('Monthly Income')
plt.show()
<ipython-input-192-be4c585955f0>:2: FutureWarning: The `ci` parameter is deprecated. Use `errorbar=None` for the same effect. sns.lineplot(data=df, x='Age', y='MonthlyIncome', hue='Attrition', ci=None)
# Label Encoding
from sklearn.preprocessing import LabelEncoder
# Columns replaced by integer codes. LabelEncoder numbers the sorted distinct
# values of a column from 0 upward.
# NOTE(review): despite its name, this list also holds nominal columns
# (e.g. Department, JobRole) and the target, Attrition.
ordinal_categorical_cols = [
    'BusinessTravel', 'Department', 'EducationField', 'Gender', 'JobRole',
    'MaritalStatus', 'Over18', 'OverTime', 'Attrition', 'Education',
    'EnvironmentSatisfaction', 'JobInvolvement', 'JobLevel',
    'JobSatisfaction', 'PerformanceRating', 'RelationshipSatisfaction',
    'StockOptionLevel', 'WorkLifeBalance',
]
label_encoder = LabelEncoder()
# Work on a copy so the raw frame stays available.
df_encoded = df.copy()
for col in ordinal_categorical_cols:
    # The one encoder is refitted for every column, so afterwards it only
    # remembers the classes of the last column in the list.
    df_encoded[col] = label_encoder.fit_transform(df[col])
df_encoded
| Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | 1 | 2 | 1102 | 2 | 1 | 1 | 1 | 1 | 1 | ... | 0 | 80 | 0 | 8 | 0 | 0 | 6 | 4 | 0 | 5 |
| 1 | 49 | 0 | 1 | 279 | 1 | 8 | 0 | 1 | 1 | 2 | ... | 3 | 80 | 1 | 10 | 3 | 2 | 10 | 7 | 1 | 7 |
| 2 | 37 | 1 | 2 | 1373 | 1 | 2 | 1 | 4 | 1 | 4 | ... | 1 | 80 | 0 | 7 | 3 | 2 | 0 | 0 | 0 | 0 |
| 3 | 33 | 0 | 1 | 1392 | 1 | 3 | 3 | 1 | 1 | 5 | ... | 2 | 80 | 0 | 8 | 3 | 2 | 8 | 7 | 3 | 0 |
| 4 | 27 | 0 | 2 | 591 | 1 | 2 | 0 | 3 | 1 | 7 | ... | 3 | 80 | 1 | 6 | 3 | 2 | 2 | 2 | 2 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1465 | 36 | 0 | 1 | 884 | 1 | 23 | 1 | 3 | 1 | 2061 | ... | 2 | 80 | 1 | 17 | 3 | 2 | 5 | 2 | 0 | 3 |
| 1466 | 39 | 0 | 2 | 613 | 1 | 6 | 0 | 3 | 1 | 2062 | ... | 0 | 80 | 1 | 9 | 5 | 2 | 7 | 7 | 1 | 7 |
| 1467 | 27 | 0 | 2 | 155 | 1 | 4 | 2 | 1 | 1 | 2064 | ... | 1 | 80 | 1 | 6 | 0 | 2 | 6 | 2 | 0 | 3 |
| 1468 | 49 | 0 | 1 | 1023 | 2 | 2 | 2 | 3 | 1 | 2065 | ... | 3 | 80 | 0 | 17 | 3 | 1 | 9 | 6 | 0 | 8 |
| 1469 | 34 | 0 | 2 | 628 | 1 | 8 | 2 | 3 | 1 | 2068 | ... | 0 | 80 | 0 | 6 | 3 | 3 | 4 | 3 | 1 | 2 |
1470 rows × 35 columns
# Outlier Detection & Removal using Z-Score Test.
from scipy import stats
# Columns screened by the z-score outlier filter below. None of the
# label-encoded columns appear in this list.
columns_to_check = [
'Age', 'DailyRate', 'MonthlyIncome', 'DistanceFromHome',
'YearsAtCompany', 'HourlyRate', 'PercentSalaryHike',
'TotalWorkingYears', 'TrainingTimesLastYear',
'YearsInCurrentRole', 'YearsSinceLastPromotion', 'YearsWithCurrManager'
]
def remove_outliers_zscore(df, columns, z_threshold=3):
    """Drop rows whose absolute z-score reaches ``z_threshold`` in a listed column.

    Columns are screened one after another on the progressively filtered
    frame, so each column's mean and standard deviation come from the rows
    that survived the columns before it.

    Args:
        df: Source DataFrame. It is copied and never modified.
        columns: Names of the numeric columns to screen, in screening order.
        z_threshold: Rows with ``|z| >= z_threshold`` are removed.

    Returns:
        A new DataFrame with only the retained rows; index labels are kept.
    """
    df_no_outliers = df.copy()
    for column in columns:
        values = df_no_outliers[column]
        # A column with at most one distinct value has zero spread, so its
        # z-scores are NaN; NaN < threshold is False and every row would be
        # silently discarded. Such a column cannot hold outliers: skip it.
        if values.nunique() <= 1:
            continue
        z_scores = np.abs(stats.zscore(values))
        df_no_outliers = df_no_outliers[z_scores < z_threshold]
    return df_no_outliers
# Filter the encoded frame, then report how many rows the filter discarded.
cleaned_data = remove_outliers_zscore(df_encoded, columns=columns_to_check)
num_removed_outliers = df_encoded.shape[0] - cleaned_data.shape[0]
print(f"Number of outliers removed: {num_removed_outliers}")
print(f"New dataset size after removing outliers: {cleaned_data.shape}")
Number of outliers removed: 107 New dataset size after removing outliers: (1363, 35)
# Cleaned encoded data set with no outliers.
cleaned_data
| Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | 1 | 2 | 1102 | 2 | 1 | 1 | 1 | 1 | 1 | ... | 0 | 80 | 0 | 8 | 0 | 0 | 6 | 4 | 0 | 5 |
| 1 | 49 | 0 | 1 | 279 | 1 | 8 | 0 | 1 | 1 | 2 | ... | 3 | 80 | 1 | 10 | 3 | 2 | 10 | 7 | 1 | 7 |
| 2 | 37 | 1 | 2 | 1373 | 1 | 2 | 1 | 4 | 1 | 4 | ... | 1 | 80 | 0 | 7 | 3 | 2 | 0 | 0 | 0 | 0 |
| 3 | 33 | 0 | 1 | 1392 | 1 | 3 | 3 | 1 | 1 | 5 | ... | 2 | 80 | 0 | 8 | 3 | 2 | 8 | 7 | 3 | 0 |
| 4 | 27 | 0 | 2 | 591 | 1 | 2 | 0 | 3 | 1 | 7 | ... | 3 | 80 | 1 | 6 | 3 | 2 | 2 | 2 | 2 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1465 | 36 | 0 | 1 | 884 | 1 | 23 | 1 | 3 | 1 | 2061 | ... | 2 | 80 | 1 | 17 | 3 | 2 | 5 | 2 | 0 | 3 |
| 1466 | 39 | 0 | 2 | 613 | 1 | 6 | 0 | 3 | 1 | 2062 | ... | 0 | 80 | 1 | 9 | 5 | 2 | 7 | 7 | 1 | 7 |
| 1467 | 27 | 0 | 2 | 155 | 1 | 4 | 2 | 1 | 1 | 2064 | ... | 1 | 80 | 1 | 6 | 0 | 2 | 6 | 2 | 0 | 3 |
| 1468 | 49 | 0 | 1 | 1023 | 2 | 2 | 2 | 3 | 1 | 2065 | ... | 3 | 80 | 0 | 17 | 3 | 1 | 9 | 6 | 0 | 8 |
| 1469 | 34 | 0 | 2 | 628 | 1 | 8 | 2 | 3 | 1 | 2068 | ... | 0 | 80 | 0 | 6 | 3 | 3 | 4 | 3 | 1 | 2 |
1363 rows × 35 columns
# Splitting Dependent and Independent variables
# Separate the label (y) from the predictors (X); every other column stays in X.
target_column = 'Attrition'
X = cleaned_data.drop(target_column, axis=1)
y = cleaned_data[target_column]
X.head()
| Age | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | 2 | 1102 | 2 | 1 | 1 | 1 | 1 | 1 | 1 | ... | 0 | 80 | 0 | 8 | 0 | 0 | 6 | 4 | 0 | 5 |
| 1 | 49 | 1 | 279 | 1 | 8 | 0 | 1 | 1 | 2 | 2 | ... | 3 | 80 | 1 | 10 | 3 | 2 | 10 | 7 | 1 | 7 |
| 2 | 37 | 2 | 1373 | 1 | 2 | 1 | 4 | 1 | 4 | 3 | ... | 1 | 80 | 0 | 7 | 3 | 2 | 0 | 0 | 0 | 0 |
| 3 | 33 | 1 | 1392 | 1 | 3 | 3 | 1 | 1 | 5 | 3 | ... | 2 | 80 | 0 | 8 | 3 | 2 | 8 | 7 | 3 | 0 |
| 4 | 27 | 2 | 591 | 1 | 2 | 0 | 3 | 1 | 7 | 0 | ... | 3 | 80 | 1 | 6 | 3 | 2 | 2 | 2 | 2 | 2 |
5 rows × 34 columns
y.head()
0 1 1 0 2 1 3 0 4 0 Name: Attrition, dtype: int64
cleaned_data.shape
(1363, 35)
X.shape
(1363, 34)
# Splitting Data into Train and Test Set
from sklearn.model_selection import train_test_split
# Hold out 20% of the rows for testing; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
for _caption, _split in (
    ("X_train shape:", X_train),
    ("X_test shape:", X_test),
    ("y_train shape:", y_train),
    ("y_test shape:", y_test),
):
    print(_caption, _split.shape)
X_train shape: (1090, 34) X_test shape: (273, 34) y_train shape: (1090,) y_test shape: (273,)
X_train.head()
| Age | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1085 | 31 | 1 | 561 | 1 | 3 | 2 | 1 | 1 | 1537 | 3 | ... | 0 | 80 | 0 | 7 | 2 | 0 | 7 | 2 | 7 | 7 |
| 1393 | 27 | 2 | 954 | 2 | 9 | 2 | 2 | 1 | 1965 | 3 | ... | 0 | 80 | 0 | 7 | 5 | 2 | 7 | 7 | 0 | 7 |
| 911 | 25 | 1 | 599 | 2 | 24 | 0 | 1 | 1 | 1273 | 2 | ... | 3 | 80 | 0 | 1 | 4 | 2 | 1 | 0 | 1 | 0 |
| 1027 | 34 | 2 | 401 | 1 | 1 | 2 | 1 | 1 | 1447 | 3 | ... | 0 | 80 | 1 | 7 | 2 | 1 | 5 | 4 | 0 | 2 |
| 56 | 35 | 1 | 853 | 2 | 18 | 4 | 1 | 1 | 74 | 1 | ... | 3 | 80 | 1 | 9 | 3 | 1 | 9 | 8 | 1 | 8 |
5 rows × 34 columns
X_test.head()
| Age | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | ... | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 460 | 26 | 2 | 775 | 2 | 29 | 1 | 3 | 1 | 618 | 0 | ... | 0 | 80 | 2 | 8 | 5 | 2 | 0 | 0 | 0 | 0 |
| 1006 | 49 | 1 | 1475 | 1 | 28 | 1 | 1 | 1 | 1420 | 0 | ... | 0 | 80 | 0 | 20 | 2 | 2 | 4 | 3 | 1 | 3 |
| 843 | 26 | 2 | 1384 | 1 | 3 | 3 | 3 | 1 | 1177 | 0 | ... | 1 | 80 | 1 | 8 | 2 | 2 | 8 | 7 | 0 | 7 |
| 486 | 37 | 2 | 558 | 2 | 2 | 2 | 2 | 1 | 656 | 3 | ... | 2 | 80 | 1 | 17 | 3 | 1 | 3 | 0 | 1 | 0 |
| 461 | 35 | 2 | 195 | 2 | 1 | 2 | 3 | 1 | 620 | 0 | ... | 3 | 80 | 0 | 5 | 3 | 2 | 5 | 4 | 0 | 3 |
5 rows × 34 columns
y_train.head()
1085 1 1393 0 911 1 1027 0 56 0 Name: Attrition, dtype: int64
y_test.head()
460 0 1006 1 843 0 486 0 461 0 Name: Attrition, dtype: int64
# Feature Scaling - Standardization (Z-score Scaling)
from sklearn.preprocessing import StandardScaler
# Learn the per-column mean and standard deviation from the training split
# only, then apply that same transform to both splits.
scaler = StandardScaler().fit(X_train)
X_train_standardized = scaler.transform(X_train)
X_test_standardized = scaler.transform(X_test)
X_train_standardized
array([[-0.57497515, -0.93717743, -0.60471296, ..., -0.56597211,
2.40499708, 1.00332169],
[-1.03218684, 0.58294667, 0.37446629, ..., 0.97630188,
-0.71630378, 1.00332169],
[-1.26079268, -0.93717743, -0.51003405, ..., -1.1828817 ,
-0.27040365, -1.1616513 ],
...,
[-0.57497515, -0.93717743, -0.12384376, ..., 0.97630188,
-0.71630378, 1.31260354],
[ 1.71108327, -0.93717743, 0.40934799, ..., -1.1828817 ,
2.40499708, 1.00332169],
[-0.11776347, -0.93717743, -1.63870631, ..., 0.05093749,
-0.27040365, 0.07547612]])
X_test_standardized
array([[-1.14648976, 0.58294667, -0.07152121, ..., -1.1828817 ,
-0.71630378, -1.1616513 ],
[ 1.48247743, -0.93717743, 1.67256396, ..., -0.25751731,
-0.27040365, -0.23380573],
[-1.14648976, 0.58294667, 1.44583289, ..., 0.97630188,
-0.71630378, 1.00332169],
...,
[ 1.13956867, 0.58294667, -0.78659613, ..., 0.97630188,
2.40499708, 1.00332169],
[-1.60370145, 0.58294667, -1.37211043, ..., -0.8744269 ,
-0.27040365, -0.54308759],
[ 2.16829496, 0.58294667, 1.58785125, ..., -0.56597211,
-0.27040365, -0.54308759]])
# Heatmap of the standardized training matrix: rows are samples, columns are features.
plt.figure(figsize=(15, 10))
sns.heatmap(data=X_train_standardized, annot=False, cbar=True, cmap='coolwarm')
plt.gca().set(
    title='Scaled Training Dataset',
    xlabel='Feature Index',
    ylabel='Sample Index',
)
plt.show()
# Logistic Regression Model
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    classification_report,
    roc_auc_score,
    roc_curve,
)

# Train on the standardized training split (fit() returns the estimator itself)
logistic_regression_model = LogisticRegression(random_state=42).fit(
    X_train_standardized, y_train
)

# Predict on the standardized test split and report accuracy plus
# the per-class precision / recall / f1 table.
y_pred_logistic_regression = logistic_regression_model.predict(X_test_standardized)
accuracy_lr = accuracy_score(y_test, y_pred_logistic_regression)
report_lr = classification_report(y_test, y_pred_logistic_regression)
print("Logistic Regression Accuracy:", accuracy_lr)
print("Classification Report:\n", report_lr)
Logistic Regression Accuracy: 0.8754578754578755
Classification Report:
precision recall f1-score support
0 0.89 0.97 0.93 229
1 0.71 0.39 0.50 44
accuracy 0.88 273
macro avg 0.80 0.68 0.71 273
weighted avg 0.86 0.88 0.86 273
# predict_proba returns one column per class; keep column 1, which is the
# probability of the second class (label 1 for these 0/1 targets).
class_probabilities_lr = logistic_regression_model.predict_proba(X_test_standardized)
probability = class_probabilities_lr[:, 1]

# Roc_Curve
# (the spelling of `threshsholds` is kept as-is because it is a notebook-level name)
fpr, tpr, threshsholds = roc_curve(y_test, probability)

# Echo the probabilities as the cell output
probability
array([2.28752232e-01, 4.41662043e-01, 1.15628793e-02, 1.71529396e-01,
6.12069188e-02, 1.01788372e-01, 9.18946542e-02, 2.30353084e-02,
1.02584502e-01, 1.80158771e-01, 1.87531530e-02, 1.96761622e-01,
1.11744505e-01, 1.18633978e-03, 1.53076817e-03, 2.59736865e-01,
3.71401092e-01, 2.20064943e-02, 5.65828781e-01, 2.59959657e-02,
5.82864836e-01, 7.37699063e-02, 5.46040881e-02, 7.11503570e-02,
6.22476276e-02, 6.76780428e-03, 3.38328916e-01, 1.79466370e-02,
3.55684590e-01, 2.52160690e-01, 4.75278624e-03, 9.02476430e-04,
3.20784546e-02, 1.09778329e-02, 2.66911378e-05, 9.46861513e-03,
1.69684078e-01, 2.85964966e-02, 1.77022479e-02, 1.01254419e-01,
1.87570635e-02, 3.10878978e-02, 1.53062038e-03, 2.02927837e-01,
3.36609817e-01, 2.55579793e-01, 2.97477252e-02, 6.80416729e-01,
1.19145507e-01, 3.41558986e-03, 6.23240447e-02, 3.51089838e-02,
1.82767021e-01, 1.86477915e-01, 1.03135389e-01, 6.52019244e-01,
1.89871820e-01, 1.06515940e-02, 3.85757099e-03, 6.90560102e-02,
7.87288776e-02, 1.75708531e-01, 3.71441057e-02, 1.72562296e-01,
6.04281653e-01, 1.88352463e-01, 3.07216078e-01, 3.72955329e-02,
9.27430404e-01, 1.72053969e-02, 3.83384349e-02, 2.14113210e-02,
7.97532844e-01, 4.61926669e-01, 4.78685754e-03, 7.35784169e-02,
1.93633014e-01, 1.72771894e-01, 3.08593808e-01, 4.70050203e-01,
2.25006025e-01, 4.41668390e-01, 1.42142678e-01, 7.25995408e-02,
2.78217500e-02, 3.44766558e-01, 2.96054364e-03, 2.79457357e-02,
8.27029124e-01, 3.00208474e-02, 7.08173425e-03, 4.94331409e-02,
5.09298172e-02, 8.09542065e-03, 1.80679919e-01, 4.13991338e-03,
2.22414917e-02, 7.10468946e-02, 1.72807360e-01, 2.87152699e-02,
1.41539013e-01, 2.72905977e-01, 1.85627382e-01, 5.39230520e-01,
4.00579140e-02, 7.51434420e-03, 7.66947757e-01, 1.05756058e-04,
6.14460092e-02, 2.59459903e-01, 2.09635661e-01, 3.65073083e-02,
3.60835922e-02, 1.34513906e-02, 2.76321528e-02, 1.02632878e-01,
7.13643516e-02, 2.26633304e-02, 3.79165619e-01, 1.17660323e-02,
1.86965506e-02, 4.59318002e-03, 8.64819365e-02, 8.67101410e-02,
4.04618184e-01, 4.49936142e-03, 7.84642480e-01, 7.68614624e-03,
2.46761987e-01, 1.40866995e-02, 1.68008182e-02, 2.00183295e-01,
1.49818154e-02, 3.54096884e-01, 9.44301983e-02, 1.22997778e-02,
1.96743140e-01, 1.26458928e-01, 1.37036145e-01, 1.89524973e-02,
5.74193005e-03, 6.98309232e-04, 4.78148792e-01, 3.28780065e-02,
8.83218768e-01, 9.96394938e-02, 1.69049802e-01, 1.63906582e-02,
1.03257157e-01, 2.06723865e-02, 8.88703156e-03, 7.85571511e-04,
8.83124268e-02, 1.04460705e-01, 1.20227722e-02, 2.28040233e-02,
1.03376797e-01, 6.51861532e-02, 9.86891946e-02, 7.69415434e-01,
1.90309445e-01, 1.89865136e-01, 4.65089982e-02, 9.50571913e-04,
2.23330068e-03, 1.51754525e-02, 6.30147117e-02, 7.11843396e-02,
1.16295019e-02, 4.82748530e-01, 8.36788764e-01, 8.60930877e-02,
5.76152042e-02, 4.36987712e-03, 1.48032187e-01, 4.04810680e-02,
3.61311150e-01, 6.34284158e-01, 9.18722792e-02, 7.53599472e-01,
6.71373634e-02, 3.38652943e-03, 2.20739857e-01, 5.21066655e-02,
1.08360491e-02, 3.75681795e-02, 8.36758922e-01, 1.64010928e-01,
5.88963521e-02, 1.70336260e-01, 1.10723613e-02, 1.81442902e-01,
1.89710786e-01, 7.56221733e-02, 2.46294538e-01, 1.80700127e-01,
1.45563981e-01, 2.60651624e-02, 3.94900209e-01, 1.75729631e-01,
4.47770145e-02, 6.51497035e-01, 4.81132112e-01, 1.50463180e-01,
9.00398658e-02, 2.86333345e-02, 1.43790075e-01, 8.11070503e-02,
7.86090184e-03, 2.75119334e-01, 2.30145959e-02, 3.04472682e-02,
1.91670971e-01, 5.63515718e-01, 1.24625239e-02, 1.54863452e-01,
4.99989984e-02, 1.45744978e-02, 2.29290260e-01, 4.06959980e-02,
2.90471323e-03, 2.27623947e-01, 7.57521466e-03, 3.23914074e-03,
2.01898018e-01, 5.22192289e-02, 2.25643547e-01, 3.61179436e-03,
1.76107353e-02, 1.36551493e-03, 1.84413565e-01, 5.93987553e-01,
1.12875140e-02, 3.48945581e-02, 1.68762983e-01, 2.35019134e-02,
9.03907432e-02, 3.51703006e-01, 4.72577550e-02, 3.16199375e-01,
4.72217627e-02, 5.65045216e-02, 1.34219019e-01, 1.45155868e-01,
3.24385058e-01, 3.16144749e-01, 5.06443773e-01, 1.72662861e-02,
4.88323743e-01, 1.76001303e-02, 3.31255886e-02, 1.13367615e-01,
2.74673248e-02, 1.37015004e-01, 9.39300936e-03, 5.46158350e-01,
3.12255295e-01, 4.68013760e-02, 4.78503471e-02, 6.30612667e-01,
2.11039727e-01, 5.13981226e-02, 1.54494214e-01, 3.95704238e-03,
2.13222733e-02, 1.57431349e-02, 5.93934876e-01, 1.98066436e-01,
2.72647470e-01, 3.15238788e-01, 1.81009421e-01, 2.94403110e-01,
1.69506884e-01])
# Draw the ROC curve from the current fpr / tpr arrays
roc_lr_fig, roc_lr_ax = plt.subplots()
roc_lr_ax.plot(fpr, tpr)
roc_lr_ax.set_xlabel('FPR')
roc_lr_ax.set_ylabel('TPR')
roc_lr_ax.set_title('ROC CURVE')
plt.show()
# Decision Tree Model
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    classification_report,
    roc_auc_score,
    roc_curve,
)

# Train on the standardized training split (fit() returns the estimator itself)
decision_tree_model = DecisionTreeClassifier(random_state=42).fit(
    X_train_standardized, y_train
)

# Predict on the standardized test split and report accuracy plus
# the per-class precision / recall / f1 table.
y_pred_decision_tree = decision_tree_model.predict(X_test_standardized)
accuracy_dt = accuracy_score(y_test, y_pred_decision_tree)
report_dt = classification_report(y_test, y_pred_decision_tree)
print("Decision Tree Accuracy:", accuracy_dt)
print("Classification Report:\n", report_dt)
Decision Tree Accuracy: 0.7728937728937729
Classification Report:
precision recall f1-score support
0 0.86 0.87 0.87 229
1 0.29 0.27 0.28 44
accuracy 0.77 273
macro avg 0.57 0.57 0.57 273
weighted avg 0.77 0.77 0.77 273
# Hard class predictions of the decision tree on the standardized test split
pred = decision_tree_model.predict(X=X_test_standardized)

# Echo the predictions as the cell output
pred
array([1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0])
# Confusion matrix of the decision-tree predictions against the test labels
confusion_matrix(y_true=y_test, y_pred=pred)
array([[199, 30],
[ 32, 12]])
# Cross-tabulate test labels (rows) against decision-tree predictions (columns)
pd.crosstab(index=y_test, columns=pred)
| col_0 | 0 | 1 |
|---|---|---|
| Attrition | ||
| 0 | 199 | 30 |
| 1 | 32 | 12 |
# predict_proba returns one column per class; keep column 1, which is the
# probability of the second class (label 1 for these 0/1 targets).
class_probabilities_dt = decision_tree_model.predict_proba(X_test_standardized)
probability = class_probabilities_dt[:, 1]

# Echo the probabilities as the cell output
probability
array([1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1.,
1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
0.])
# Roc_Curve
# Compute the ROC points from the current `probability` scores, then plot them.
# (the spelling of `threshsholds` is kept as-is because it is a notebook-level name)
fpr, tpr, threshsholds = roc_curve(y_test, probability)

roc_dt_fig, roc_dt_ax = plt.subplots()
roc_dt_ax.plot(fpr, tpr)
roc_dt_ax.set_xlabel('FPR')
roc_dt_ax.set_ylabel('TPR')
roc_dt_ax.set_title('ROC CURVE')
plt.show()
# Visualize the fitted decision tree on a large canvas.
from sklearn import tree

plt.figure(figsize=(25, 15))
# plot_tree returns the list of drawn annotations. Binding it keeps the
# notebook from echoing hundreds of Text(...) reprs as the cell output.
tree_annotations = tree.plot_tree(decision_tree_model, filled=True)
# Render the figure explicitly, as the other plotting cells do.
plt.show()
[Text(0.4673406862745098, 0.9642857142857143, 'x[21] <= 0.498\ngini = 0.279\nsamples = 1090\nvalue = [907, 183]'), Text(0.20147058823529412, 0.8928571428571429, 'x[27] <= -1.161\ngini = 0.196\nsamples = 788\nvalue = [701, 87]'), Text(0.06274509803921569, 0.8214285714285714, 'x[11] <= -0.34\ngini = 0.466\nsamples = 73\nvalue = [46, 27]'), Text(0.03137254901960784, 0.75, 'x[9] <= -0.214\ngini = 0.459\nsamples = 28\nvalue = [10, 18]'), Text(0.01568627450980392, 0.6785714285714286, 'x[30] <= -1.207\ngini = 0.142\nsamples = 13\nvalue = [1, 12]'), Text(0.00784313725490196, 0.6071428571428571, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.023529411764705882, 0.6071428571428571, 'gini = 0.0\nsamples = 12\nvalue = [0, 12]'), Text(0.047058823529411764, 0.6785714285714286, 'x[17] <= -0.915\ngini = 0.48\nsamples = 15\nvalue = [9, 6]'), Text(0.0392156862745098, 0.6071428571428571, 'x[17] <= -1.09\ngini = 0.375\nsamples = 8\nvalue = [2, 6]'), Text(0.03137254901960784, 0.5357142857142857, 'x[2] <= -0.167\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.023529411764705882, 0.4642857142857143, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.0392156862745098, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.047058823529411764, 0.5357142857142857, 'gini = 0.0\nsamples = 5\nvalue = [0, 5]'), Text(0.054901960784313725, 0.6071428571428571, 'gini = 0.0\nsamples = 7\nvalue = [7, 0]'), Text(0.09411764705882353, 0.75, 'x[5] <= 0.591\ngini = 0.32\nsamples = 45\nvalue = [36, 9]'), Text(0.0784313725490196, 0.6785714285714286, 'x[11] <= 0.702\ngini = 0.25\nsamples = 41\nvalue = [35, 6]'), Text(0.07058823529411765, 0.6071428571428571, 'gini = 0.0\nsamples = 21\nvalue = [21, 0]'), Text(0.08627450980392157, 0.6071428571428571, 'x[2] <= -1.256\ngini = 0.42\nsamples = 20\nvalue = [14, 6]'), Text(0.0784313725490196, 0.5357142857142857, 'gini = 0.0\nsamples = 3\nvalue = [0, 3]'), Text(0.09411764705882353, 0.5357142857142857, 'x[29] <= -1.772\ngini = 0.291\nsamples = 
17\nvalue = [14, 3]'), Text(0.08627450980392157, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.10196078431372549, 0.4642857142857143, 'x[11] <= 1.693\ngini = 0.219\nsamples = 16\nvalue = [14, 2]'), Text(0.09411764705882353, 0.39285714285714285, 'x[8] <= 1.396\ngini = 0.124\nsamples = 15\nvalue = [14, 1]'), Text(0.08627450980392157, 0.32142857142857145, 'gini = 0.0\nsamples = 14\nvalue = [14, 0]'), Text(0.10196078431372549, 0.32142857142857145, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.10980392156862745, 0.39285714285714285, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.10980392156862745, 0.6785714285714286, 'x[11] <= 0.156\ngini = 0.375\nsamples = 4\nvalue = [1, 3]'), Text(0.10196078431372549, 0.6071428571428571, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.11764705882352941, 0.6071428571428571, 'gini = 0.0\nsamples = 3\nvalue = [0, 3]'), Text(0.34019607843137256, 0.8214285714285714, 'x[15] <= -1.104\ngini = 0.154\nsamples = 715\nvalue = [655, 60]'), Text(0.2519607843137255, 0.75, 'x[26] <= -0.344\ngini = 0.296\nsamples = 144\nvalue = [118, 26]'), Text(0.19411764705882353, 0.6785714285714286, 'x[4] <= -0.076\ngini = 0.428\nsamples = 58\nvalue = [40, 18]'), Text(0.1568627450980392, 0.6071428571428571, 'x[18] <= 0.412\ngini = 0.278\nsamples = 36\nvalue = [30, 6]'), Text(0.14901960784313725, 0.5357142857142857, 'x[19] <= 0.346\ngini = 0.444\nsamples = 18\nvalue = [12, 6]'), Text(0.13333333333333333, 0.4642857142857143, 'x[8] <= -0.726\ngini = 0.278\nsamples = 12\nvalue = [10, 2]'), Text(0.12549019607843137, 0.39285714285714285, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.1411764705882353, 0.39285714285714285, 'gini = 0.0\nsamples = 10\nvalue = [10, 0]'), Text(0.16470588235294117, 0.4642857142857143, 'x[27] <= 0.695\ngini = 0.444\nsamples = 6\nvalue = [2, 4]'), Text(0.1568627450980392, 0.39285714285714285, 'gini = 0.0\nsamples = 4\nvalue = [0, 4]'), Text(0.17254901960784313, 0.39285714285714285, 'gini = 0.0\nsamples = 
2\nvalue = [2, 0]'), Text(0.16470588235294117, 0.5357142857142857, 'gini = 0.0\nsamples = 18\nvalue = [18, 0]'), Text(0.23137254901960785, 0.6071428571428571, 'x[19] <= -0.057\ngini = 0.496\nsamples = 22\nvalue = [10, 12]'), Text(0.21176470588235294, 0.5357142857142857, 'x[31] <= -0.103\ngini = 0.426\nsamples = 13\nvalue = [9, 4]'), Text(0.19607843137254902, 0.4642857142857143, 'x[0] <= -0.746\ngini = 0.48\nsamples = 5\nvalue = [2, 3]'), Text(0.18823529411764706, 0.39285714285714285, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.20392156862745098, 0.39285714285714285, 'gini = 0.0\nsamples = 3\nvalue = [0, 3]'), Text(0.22745098039215686, 0.4642857142857143, 'x[6] <= 1.669\ngini = 0.219\nsamples = 8\nvalue = [7, 1]'), Text(0.2196078431372549, 0.39285714285714285, 'gini = 0.0\nsamples = 7\nvalue = [7, 0]'), Text(0.23529411764705882, 0.39285714285714285, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.25098039215686274, 0.5357142857142857, 'x[0] <= 1.597\ngini = 0.198\nsamples = 9\nvalue = [1, 8]'), Text(0.24313725490196078, 0.4642857142857143, 'gini = 0.0\nsamples = 8\nvalue = [0, 8]'), Text(0.25882352941176473, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.30980392156862746, 0.6785714285714286, 'x[11] <= 1.693\ngini = 0.169\nsamples = 86\nvalue = [78, 8]'), Text(0.30196078431372547, 0.6071428571428571, 'x[30] <= 2.527\ngini = 0.151\nsamples = 85\nvalue = [78, 7]'), Text(0.29411764705882354, 0.5357142857142857, 'x[29] <= -1.772\ngini = 0.133\nsamples = 84\nvalue = [78, 6]'), Text(0.27450980392156865, 0.4642857142857143, 'x[12] <= -0.313\ngini = 0.48\nsamples = 5\nvalue = [3, 2]'), Text(0.26666666666666666, 0.39285714285714285, 'x[22] <= -0.886\ngini = 0.444\nsamples = 3\nvalue = [1, 2]'), Text(0.25882352941176473, 0.32142857142857145, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.27450980392156865, 0.32142857142857145, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.2823529411764706, 0.39285714285714285, 'gini = 0.0\nsamples = 
2\nvalue = [2, 0]'), Text(0.3137254901960784, 0.4642857142857143, 'x[14] <= 1.198\ngini = 0.096\nsamples = 79\nvalue = [75, 4]'), Text(0.2980392156862745, 0.39285714285714285, 'x[18] <= -0.691\ngini = 0.075\nsamples = 77\nvalue = [74, 3]'), Text(0.2901960784313726, 0.32142857142857145, 'x[18] <= -0.7\ngini = 0.227\nsamples = 23\nvalue = [20, 3]'), Text(0.2823529411764706, 0.25, 'x[33] <= 0.849\ngini = 0.165\nsamples = 22\nvalue = [20, 2]'), Text(0.27450980392156865, 0.17857142857142858, 'gini = 0.0\nsamples = 16\nvalue = [16, 0]'), Text(0.2901960784313726, 0.17857142857142858, 'x[18] <= -1.313\ngini = 0.444\nsamples = 6\nvalue = [4, 2]'), Text(0.2823529411764706, 0.10714285714285714, 'gini = 0.0\nsamples = 4\nvalue = [4, 0]'), Text(0.2980392156862745, 0.10714285714285714, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.2980392156862745, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.3058823529411765, 0.32142857142857145, 'gini = 0.0\nsamples = 54\nvalue = [54, 0]'), Text(0.32941176470588235, 0.39285714285714285, 'x[6] <= 0.926\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.3215686274509804, 0.32142857142857145, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.33725490196078434, 0.32142857142857145, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.30980392156862746, 0.5357142857142857, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.3176470588235294, 0.6071428571428571, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.4284313725490196, 0.75, 'x[26] <= -0.344\ngini = 0.112\nsamples = 571\nvalue = [537, 34]'), Text(0.35784313725490197, 0.6785714285714286, 'x[12] <= -1.722\ngini = 0.18\nsamples = 230\nvalue = [207, 23]'), Text(0.3333333333333333, 0.6071428571428571, 'x[2] <= 0.387\ngini = 0.486\nsamples = 12\nvalue = [7, 5]'), Text(0.3254901960784314, 0.5357142857142857, 'gini = 0.0\nsamples = 7\nvalue = [7, 0]'), Text(0.3411764705882353, 0.5357142857142857, 'gini = 0.0\nsamples = 5\nvalue = [0, 5]'), Text(0.38235294117647056, 
0.6071428571428571, 'x[2] <= -1.715\ngini = 0.152\nsamples = 218\nvalue = [200, 18]'), Text(0.3568627450980392, 0.5357142857142857, 'x[24] <= -0.632\ngini = 0.444\nsamples = 3\nvalue = [1, 2]'), Text(0.34901960784313724, 0.4642857142857143, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.36470588235294116, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.40784313725490196, 0.5357142857142857, 'x[0] <= -0.518\ngini = 0.138\nsamples = 215\nvalue = [199, 16]'), Text(0.3803921568627451, 0.4642857142857143, 'x[30] <= -0.768\ngini = 0.27\nsamples = 56\nvalue = [47, 9]'), Text(0.36470588235294116, 0.39285714285714285, 'x[19] <= 0.749\ngini = 0.444\nsamples = 9\nvalue = [3, 6]'), Text(0.3568627450980392, 0.32142857142857145, 'x[13] <= 0.592\ngini = 0.375\nsamples = 4\nvalue = [3, 1]'), Text(0.34901960784313724, 0.25, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.36470588235294116, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.37254901960784315, 0.32142857142857145, 'gini = 0.0\nsamples = 5\nvalue = [0, 5]'), Text(0.396078431372549, 0.39285714285714285, 'x[32] <= 2.182\ngini = 0.12\nsamples = 47\nvalue = [44, 3]'), Text(0.38823529411764707, 0.32142857142857145, 'x[4] <= -0.949\ngini = 0.083\nsamples = 46\nvalue = [44, 2]'), Text(0.3803921568627451, 0.25, 'x[2] <= 0.368\ngini = 0.48\nsamples = 5\nvalue = [3, 2]'), Text(0.37254901960784315, 0.17857142857142858, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.38823529411764707, 0.17857142857142858, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.396078431372549, 0.25, 'gini = 0.0\nsamples = 41\nvalue = [41, 0]'), Text(0.403921568627451, 0.32142857142857145, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.43529411764705883, 0.4642857142857143, 'x[14] <= 0.386\ngini = 0.084\nsamples = 159\nvalue = [152, 7]'), Text(0.42745098039215684, 0.39285714285714285, 'gini = 0.0\nsamples = 95\nvalue = [95, 0]'), Text(0.44313725490196076, 0.39285714285714285, 'x[0] <= 1.768\ngini = 
0.195\nsamples = 64\nvalue = [57, 7]'), Text(0.42745098039215684, 0.32142857142857145, 'x[2] <= 1.661\ngini = 0.126\nsamples = 59\nvalue = [55, 4]'), Text(0.4196078431372549, 0.25, 'x[17] <= 0.833\ngini = 0.098\nsamples = 58\nvalue = [55, 3]'), Text(0.403921568627451, 0.17857142857142858, 'x[4] <= 2.418\ngini = 0.037\nsamples = 53\nvalue = [52, 1]'), Text(0.396078431372549, 0.10714285714285714, 'gini = 0.0\nsamples = 51\nvalue = [51, 0]'), Text(0.4117647058823529, 0.10714285714285714, 'x[30] <= 0.769\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.403921568627451, 0.03571428571428571, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.4196078431372549, 0.03571428571428571, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.43529411764705883, 0.17857142857142858, 'x[8] <= -0.142\ngini = 0.48\nsamples = 5\nvalue = [3, 2]'), Text(0.42745098039215684, 0.10714285714285714, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.44313725490196076, 0.10714285714285714, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.43529411764705883, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.4588235294117647, 0.32142857142857145, 'x[27] <= 0.386\ngini = 0.48\nsamples = 5\nvalue = [2, 3]'), Text(0.45098039215686275, 0.25, 'gini = 0.0\nsamples = 3\nvalue = [0, 3]'), Text(0.4666666666666667, 0.25, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.49901960784313726, 0.6785714285714286, 'x[4] <= -0.45\ngini = 0.062\nsamples = 341\nvalue = [330, 11]'), Text(0.49117647058823527, 0.6071428571428571, 'gini = 0.0\nsamples = 154\nvalue = [154, 0]'), Text(0.5068627450980392, 0.6071428571428571, 'x[6] <= -1.305\ngini = 0.111\nsamples = 187\nvalue = [176, 11]'), Text(0.47843137254901963, 0.5357142857142857, 'x[2] <= 0.042\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.47058823529411764, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.48627450980392156, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.5352941176470588, 
0.5357142857142857, 'x[17] <= 0.307\ngini = 0.102\nsamples = 185\nvalue = [175, 10]'), Text(0.5019607843137255, 0.4642857142857143, 'x[11] <= -1.579\ngini = 0.055\nsamples = 142\nvalue = [138, 4]'), Text(0.4823529411764706, 0.39285714285714285, 'x[17] <= -0.783\ngini = 0.444\nsamples = 6\nvalue = [4, 2]'), Text(0.4745098039215686, 0.32142857142857145, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.49019607843137253, 0.32142857142857145, 'gini = 0.0\nsamples = 4\nvalue = [4, 0]'), Text(0.5215686274509804, 0.39285714285714285, 'x[30] <= -0.988\ngini = 0.029\nsamples = 136\nvalue = [134, 2]'), Text(0.5058823529411764, 0.32142857142857145, 'x[9] <= -0.672\ngini = 0.32\nsamples = 5\nvalue = [4, 1]'), Text(0.4980392156862745, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.5137254901960784, 0.25, 'gini = 0.0\nsamples = 4\nvalue = [4, 0]'), Text(0.5372549019607843, 0.32142857142857145, 'x[33] <= -0.698\ngini = 0.015\nsamples = 131\nvalue = [130, 1]'), Text(0.5294117647058824, 0.25, 'x[30] <= 0.769\ngini = 0.124\nsamples = 15\nvalue = [14, 1]'), Text(0.5215686274509804, 0.17857142857142858, 'gini = 0.0\nsamples = 13\nvalue = [13, 0]'), Text(0.5372549019607843, 0.17857142857142858, 'x[9] <= 0.244\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'), Text(0.5294117647058824, 0.10714285714285714, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.5450980392156862, 0.10714285714285714, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.5450980392156862, 0.25, 'gini = 0.0\nsamples = 116\nvalue = [116, 0]'), Text(0.5686274509803921, 0.4642857142857143, 'x[26] <= 2.015\ngini = 0.24\nsamples = 43\nvalue = [37, 6]'), Text(0.5607843137254902, 0.39285714285714285, 'x[29] <= 1.029\ngini = 0.139\nsamples = 40\nvalue = [37, 3]'), Text(0.5529411764705883, 0.32142857142857145, 'gini = 0.0\nsamples = 33\nvalue = [33, 0]'), Text(0.5686274509803921, 0.32142857142857145, 'x[17] <= 1.092\ngini = 0.49\nsamples = 7\nvalue = [4, 3]'), Text(0.5607843137254902, 0.25, 'gini = 0.0\nsamples = 
3\nvalue = [0, 3]'), Text(0.5764705882352941, 0.25, 'gini = 0.0\nsamples = 4\nvalue = [4, 0]'), Text(0.5764705882352941, 0.39285714285714285, 'gini = 0.0\nsamples = 3\nvalue = [0, 3]'), Text(0.7332107843137254, 0.8928571428571429, 'x[17] <= -0.819\ngini = 0.434\nsamples = 302\nvalue = [206, 96]'), Text(0.6470588235294118, 0.8214285714285714, 'x[2] <= 0.317\ngini = 0.441\nsamples = 58\nvalue = [19, 39]'), Text(0.6313725490196078, 0.75, 'x[28] <= 2.112\ngini = 0.208\nsamples = 34\nvalue = [4, 30]'), Text(0.6235294117647059, 0.6785714285714286, 'x[27] <= 0.618\ngini = 0.165\nsamples = 33\nvalue = [3, 30]'), Text(0.615686274509804, 0.6071428571428571, 'x[29] <= -0.371\ngini = 0.117\nsamples = 32\nvalue = [2, 30]'), Text(0.6078431372549019, 0.5357142857142857, 'x[11] <= -0.513\ngini = 0.32\nsamples = 10\nvalue = [2, 8]'), Text(0.6, 0.4642857142857143, 'x[22] <= 0.09\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.592156862745098, 0.39285714285714285, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.6078431372549019, 0.39285714285714285, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.615686274509804, 0.4642857142857143, 'gini = 0.0\nsamples = 7\nvalue = [0, 7]'), Text(0.6235294117647059, 0.5357142857142857, 'gini = 0.0\nsamples = 22\nvalue = [0, 22]'), Text(0.6313725490196078, 0.6071428571428571, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6392156862745098, 0.6785714285714286, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6627450980392157, 0.75, 'x[18] <= 0.243\ngini = 0.469\nsamples = 24\nvalue = [15, 9]'), Text(0.6549019607843137, 0.6785714285714286, 'x[31] <= -0.412\ngini = 0.48\nsamples = 15\nvalue = [6, 9]'), Text(0.6470588235294118, 0.6071428571428571, 'x[4] <= -0.45\ngini = 0.375\nsamples = 12\nvalue = [3, 9]'), Text(0.6392156862745098, 0.5357142857142857, 'x[12] <= 1.096\ngini = 0.5\nsamples = 6\nvalue = [3, 3]'), Text(0.6313725490196078, 0.4642857142857143, 'x[2] <= 1.564\ngini = 0.375\nsamples = 4\nvalue = [1, 3]'), 
Text(0.6235294117647059, 0.39285714285714285, 'gini = 0.0\nsamples = 3\nvalue = [0, 3]'), Text(0.6392156862745098, 0.39285714285714285, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6470588235294118, 0.4642857142857143, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.6549019607843137, 0.5357142857142857, 'gini = 0.0\nsamples = 6\nvalue = [0, 6]'), Text(0.6627450980392157, 0.6071428571428571, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.6705882352941176, 0.6785714285714286, 'gini = 0.0\nsamples = 9\nvalue = [9, 0]'), Text(0.8193627450980392, 0.8214285714285714, 'x[9] <= -0.214\ngini = 0.358\nsamples = 244\nvalue = [187, 57]'), Text(0.7622549019607843, 0.75, 'x[19] <= 1.151\ngini = 0.473\nsamples = 86\nvalue = [53, 33]'), Text(0.7127450980392157, 0.6785714285714286, 'x[27] <= -0.697\ngini = 0.424\nsamples = 72\nvalue = [50, 22]'), Text(0.6784313725490196, 0.6071428571428571, 'x[33] <= -0.388\ngini = 0.42\nsamples = 10\nvalue = [3, 7]'), Text(0.6705882352941176, 0.5357142857142857, 'x[4] <= -0.949\ngini = 0.219\nsamples = 8\nvalue = [1, 7]'), Text(0.6627450980392157, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.6784313725490196, 0.4642857142857143, 'gini = 0.0\nsamples = 7\nvalue = [0, 7]'), Text(0.6862745098039216, 0.5357142857142857, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.7470588235294118, 0.6071428571428571, 'x[15] <= 0.706\ngini = 0.367\nsamples = 62\nvalue = [47, 15]'), Text(0.7215686274509804, 0.5357142857142857, 'x[12] <= -0.313\ngini = 0.46\nsamples = 39\nvalue = [25, 14]'), Text(0.6941176470588235, 0.4642857142857143, 'x[6] <= 0.182\ngini = 0.473\nsamples = 13\nvalue = [5, 8]'), Text(0.6862745098039216, 0.39285714285714285, 'gini = 0.0\nsamples = 7\nvalue = [0, 7]'), Text(0.7019607843137254, 0.39285714285714285, 'x[2] <= 0.119\ngini = 0.278\nsamples = 6\nvalue = [5, 1]'), Text(0.6941176470588235, 0.32142857142857145, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.7098039215686275, 0.32142857142857145, 
'gini = 0.0\nsamples = 5\nvalue = [5, 0]'), Text(0.7490196078431373, 0.4642857142857143, 'x[6] <= 1.669\ngini = 0.355\nsamples = 26\nvalue = [20, 6]'), Text(0.7411764705882353, 0.39285714285714285, 'x[17] <= -0.668\ngini = 0.278\nsamples = 24\nvalue = [20, 4]'), Text(0.7254901960784313, 0.32142857142857145, 'x[8] <= -0.405\ngini = 0.444\nsamples = 3\nvalue = [1, 2]'), Text(0.7176470588235294, 0.25, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.7333333333333333, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.7568627450980392, 0.32142857142857145, 'x[4] <= 2.418\ngini = 0.172\nsamples = 21\nvalue = [19, 2]'), Text(0.7490196078431373, 0.25, 'x[27] <= -0.542\ngini = 0.095\nsamples = 20\nvalue = [19, 1]'), Text(0.7411764705882353, 0.17857142857142858, 'x[11] <= 0.776\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.7333333333333333, 0.10714285714285714, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.7490196078431373, 0.10714285714285714, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.7568627450980392, 0.17857142857142858, 'gini = 0.0\nsamples = 17\nvalue = [17, 0]'), Text(0.7647058823529411, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.7568627450980392, 0.39285714285714285, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.7725490196078432, 0.5357142857142857, 'x[8] <= 1.446\ngini = 0.083\nsamples = 23\nvalue = [22, 1]'), Text(0.7647058823529411, 0.4642857142857143, 'gini = 0.0\nsamples = 22\nvalue = [22, 0]'), Text(0.7803921568627451, 0.4642857142857143, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8117647058823529, 0.6785714285714286, 'x[33] <= -0.852\ngini = 0.337\nsamples = 14\nvalue = [3, 11]'), Text(0.796078431372549, 0.6071428571428571, 'x[2] <= 0.757\ngini = 0.444\nsamples = 3\nvalue = [2, 1]'), Text(0.788235294117647, 0.5357142857142857, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.803921568627451, 0.5357142857142857, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8274509803921568, 0.6071428571428571, 
'x[22] <= 1.483\ngini = 0.165\nsamples = 11\nvalue = [1, 10]'), Text(0.8196078431372549, 0.5357142857142857, 'gini = 0.0\nsamples = 10\nvalue = [0, 10]'), Text(0.8352941176470589, 0.5357142857142857, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'), Text(0.8764705882352941, 0.75, 'x[28] <= -1.794\ngini = 0.258\nsamples = 158\nvalue = [134, 24]'), Text(0.8509803921568627, 0.6785714285714286, 'x[11] <= 1.123\ngini = 0.444\nsamples = 9\nvalue = [3, 6]'), Text(0.8431372549019608, 0.6071428571428571, 'gini = 0.0\nsamples = 6\nvalue = [0, 6]'), Text(0.8588235294117647, 0.6071428571428571, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.9019607843137255, 0.6785714285714286, 'x[27] <= -1.315\ngini = 0.212\nsamples = 149\nvalue = [131, 18]'), Text(0.8745098039215686, 0.6071428571428571, 'x[6] <= 0.182\ngini = 0.444\nsamples = 6\nvalue = [2, 4]'), Text(0.8666666666666667, 0.5357142857142857, 'gini = 0.0\nsamples = 4\nvalue = [0, 4]'), Text(0.8823529411764706, 0.5357142857142857, 'gini = 0.0\nsamples = 2\nvalue = [2, 0]'), Text(0.9294117647058824, 0.6071428571428571, 'x[14] <= 0.792\ngini = 0.177\nsamples = 143\nvalue = [129, 14]'), Text(0.8980392156862745, 0.5357142857142857, 'x[11] <= 1.495\ngini = 0.078\nsamples = 99\nvalue = [95, 4]'), Text(0.8823529411764706, 0.4642857142857143, 'x[17] <= -0.746\ngini = 0.042\nsamples = 94\nvalue = [92, 2]'), Text(0.8745098039215686, 0.39285714285714285, 'x[30] <= 0.769\ngini = 0.298\nsamples = 11\nvalue = [9, 2]'), Text(0.8666666666666667, 0.32142857142857145, 'x[4] <= 0.734\ngini = 0.18\nsamples = 10\nvalue = [9, 1]'), Text(0.8588235294117647, 0.25, 'gini = 0.0\nsamples = 9\nvalue = [9, 0]'), Text(0.8745098039215686, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8823529411764706, 0.32142857142857145, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.8901960784313725, 0.39285714285714285, 'gini = 0.0\nsamples = 83\nvalue = [83, 0]'), Text(0.9137254901960784, 0.4642857142857143, 'x[17] <= -0.706\ngini = 0.48\nsamples = 5\nvalue 
= [3, 2]'), Text(0.9058823529411765, 0.39285714285714285, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.9215686274509803, 0.39285714285714285, 'gini = 0.0\nsamples = 2\nvalue = [0, 2]'), Text(0.9607843137254902, 0.5357142857142857, 'x[33] <= -1.007\ngini = 0.351\nsamples = 44\nvalue = [34, 10]'), Text(0.9450980392156862, 0.4642857142857143, 'x[15] <= 0.706\ngini = 0.469\nsamples = 8\nvalue = [3, 5]'), Text(0.9372549019607843, 0.39285714285714285, 'gini = 0.0\nsamples = 5\nvalue = [0, 5]'), Text(0.9529411764705882, 0.39285714285714285, 'gini = 0.0\nsamples = 3\nvalue = [3, 0]'), Text(0.9764705882352941, 0.4642857142857143, 'x[16] <= 0.554\ngini = 0.239\nsamples = 36\nvalue = [31, 5]'), Text(0.9686274509803922, 0.39285714285714285, 'gini = 0.0\nsamples = 25\nvalue = [25, 0]'), Text(0.984313725490196, 0.39285714285714285, 'x[32] <= -0.047\ngini = 0.496\nsamples = 11\nvalue = [6, 5]'), Text(0.9764705882352941, 0.32142857142857145, 'x[18] <= 1.295\ngini = 0.245\nsamples = 7\nvalue = [6, 1]'), Text(0.9686274509803922, 0.25, 'gini = 0.0\nsamples = 6\nvalue = [6, 0]'), Text(0.984313725490196, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'), Text(0.9921568627450981, 0.32142857142857145, 'gini = 0.0\nsamples = 4\nvalue = [0, 4]')]
from sklearn.model_selection import GridSearchCV

# Hyperparameter grid searched for the decision tree.
# 'auto' is deliberately left out of max_features: scikit-learn deprecated it
# in 1.1 and removes it in 1.3, and its FutureWarning states that 'sqrt' keeps
# the same behaviour. Listing both only duplicates candidates, floods the
# output with warnings, and fails on scikit-learn >= 1.3.
parameter = {
    'criterion': ['gini', 'entropy'],
    'splitter': ['best', 'random'],
    'max_depth': [1, 2, 3, 4, 5],
    'max_features': ['sqrt', 'log2'],
}

# 5-fold cross-validated exhaustive search, ranked by accuracy.
grid_search = GridSearchCV(
    estimator=decision_tree_model,
    param_grid=parameter,
    cv=5,
    scoring="accuracy",
)
grid_search.fit(X_train_standardized, y_train)
/usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. 
warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn( /usr/local/lib/python3.10/dist-packages/sklearn/tree/_classes.py:269: FutureWarning: `max_features='auto'` has been deprecated in 1.1 and will be removed in 1.3. To keep the past behaviour, explicitly set `max_features='sqrt'`. warnings.warn(
GridSearchCV(cv=5, estimator=DecisionTreeClassifier(random_state=42),
param_grid={'criterion': ['gini', 'entropy'],
'max_depth': [1, 2, 3, 4, 5],
'max_features': ['auto', 'sqrt', 'log2'],
'splitter': ['best', 'random']},
scoring='accuracy')In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. GridSearchCV(cv=5, estimator=DecisionTreeClassifier(random_state=42),
param_grid={'criterion': ['gini', 'entropy'],
'max_depth': [1, 2, 3, 4, 5],
'max_features': ['auto', 'sqrt', 'log2'],
'splitter': ['best', 'random']},
scoring='accuracy')DecisionTreeClassifier(random_state=42)
DecisionTreeClassifier(random_state=42)
# Show the hyperparameter combination selected by the 5-fold, accuracy-scored
# grid search over the decision tree.
grid_search.best_params_
{'criterion': 'gini',
'max_depth': 1,
'max_features': 'auto',
'splitter': 'best'}
# Decision tree refit with explicitly chosen hyperparameters.
# NOTE(review): these values are hand-picked and are not read from
# grid_search.best_params_ -- confirm that this is intentional.
# random_state is fixed because max_features='sqrt' samples a random subset of
# features at every split; without a seed the fitted tree (and every metric
# derived from it) changes from run to run. 42 matches the seed used by the
# baseline decision tree.
decision_tree_model_cv = DecisionTreeClassifier(
    criterion='entropy',
    max_depth=3,
    max_features='sqrt',
    splitter='best',
    random_state=42,
)
decision_tree_model_cv.fit(X_train_standardized, y_train)
DecisionTreeClassifier(criterion='entropy', max_depth=3, max_features='sqrt')In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
DecisionTreeClassifier(criterion='entropy', max_depth=3, max_features='sqrt')
# Score the tuned decision tree on the standardized test split and print the
# per-class precision / recall / F1 table.
pred = decision_tree_model_cv.predict(X_test_standardized)
print(classification_report(y_true=y_test, y_pred=pred))
precision recall f1-score support
0 0.86 0.96 0.91 229
1 0.47 0.18 0.26 44
accuracy 0.84 273
macro avg 0.66 0.57 0.58 273
weighted avg 0.80 0.84 0.80 273
# Random Forest Model
from sklearn.ensemble import RandomForestClassifier

# Seeded so that the bootstrap / feature sampling, and therefore the selected
# hyperparameters and reported scores, are reproducible between runs.
rfc = RandomForestClassifier(random_state=42)

# max_features must be an int >= 1 (or a float / 'sqrt' / 'log2' / None).
# Starting the range at 0 made every max_features=0 candidate raise
# InvalidParameterError, which produced failed fits and NaN scores in the
# search, so the range starts at 1.
forest_params = [{
    'max_depth': list(range(10, 15)),
    'max_features': list(range(1, 14)),
}]

# 10-fold cross-validated exhaustive search, ranked by accuracy.
rfc_cv = GridSearchCV(rfc, param_grid=forest_params, cv=10, scoring="accuracy")
rfc_cv.fit(X_train_standardized, y_train)
/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py:378: FitFailedWarning:
50 fits failed out of a total of 700.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.
Below are more details about the failures:
--------------------------------------------------------------------------------
50 fits failed with the following error:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "/usr/local/lib/python3.10/dist-packages/sklearn/ensemble/_forest.py", line 340, in fit
self._validate_params()
File "/usr/local/lib/python3.10/dist-packages/sklearn/base.py", line 600, in _validate_params
validate_parameter_constraints(
File "/usr/local/lib/python3.10/dist-packages/sklearn/utils/_param_validation.py", line 97, in validate_parameter_constraints
raise InvalidParameterError(
sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of RandomForestClassifier must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'auto' (deprecated), 'sqrt', 'log2'} or None. Got 0 instead.
warnings.warn(some_fits_failed_message, FitFailedWarning)
/usr/local/lib/python3.10/dist-packages/sklearn/model_selection/_search.py:952: UserWarning: One or more of the test scores are non-finite: [ nan 0.83853211 0.84311927 0.84587156 0.84862385 0.84587156
0.85229358 0.84678899 0.85321101 0.85412844 0.85412844 0.84770642
0.85045872 0.85229358 nan 0.83761468 0.84587156 0.84678899
0.85321101 0.85137615 0.85321101 0.85688073 0.85229358 0.84770642
0.85045872 0.84862385 0.84954128 0.84954128 nan 0.83486239
0.8412844 0.85229358 0.84678899 0.85045872 0.84862385 0.84770642
0.85229358 0.85137615 0.85412844 0.85137615 0.85321101 0.85045872
nan 0.84036697 0.84495413 0.84678899 0.85045872 0.85321101
0.84587156 0.84862385 0.85137615 0.84770642 0.85321101 0.85412844
0.85321101 0.85229358 nan 0.83944954 0.84587156 0.84678899
0.85137615 0.84862385 0.85412844 0.85779817 0.85229358 0.85137615
0.85229358 0.85045872 0.85045872 0.84678899]
warnings.warn(
GridSearchCV(cv=10, estimator=RandomForestClassifier(),
param_grid=[{'max_depth': [10, 11, 12, 13, 14],
'max_features': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13]}],
scoring='accuracy')In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. GridSearchCV(cv=10, estimator=RandomForestClassifier(),
param_grid=[{'max_depth': [10, 11, 12, 13, 14],
'max_features': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13]}],
scoring='accuracy')RandomForestClassifier()
RandomForestClassifier()
# Predict with the grid-searched random forest on the standardized test split
# and print the per-class precision / recall / F1 table.
# `pred` is kept as the variable name because the combined-metrics cell reads it.
pred = rfc_cv.predict(X_test_standardized)
print(classification_report(y_true=y_test, y_pred=pred))
precision recall f1-score support
0 0.86 0.99 0.92 229
1 0.73 0.18 0.29 44
accuracy 0.86 273
macro avg 0.79 0.58 0.61 273
weighted avg 0.84 0.86 0.82 273
# Show the max_depth / max_features combination selected by the 10-fold,
# accuracy-scored random forest grid search.
rfc_cv.best_params_
{'max_depth': 14, 'max_features': 7}
# Combined Performance Metrics Calculations
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


def _compute_metrics(y_true, y_pred):
    """Return (accuracy, precision, recall, F1, confusion matrix) for one set of predictions.

    Every metric is called with scikit-learn's default arguments, exactly as
    the per-model calls did before they were factored into this helper.
    """
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred),
        recall_score(y_true, y_pred),
        f1_score(y_true, y_pred),
        confusion_matrix(y_true, y_pred),
    )


def _print_metrics(title, accuracy, precision, recall, f1, matrix):
    """Print one model's metrics under ``title``; scalars are shown to two decimals."""
    print(title)
    print(f"Accuracy: {accuracy:.2f}")
    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")
    print(f"F1 Score: {f1:.2f}")
    print("Confusion Matrix:")
    print(matrix)


# Calculate performance metrics for Logistic Regression
(accuracy_lr, precision_lr, recall_lr,
 f1_score_lr, confusion_matrix_lr) = _compute_metrics(y_test, y_pred_logistic_regression)

# Calculate performance metrics for Decision Tree
(accuracy_dt, precision_dt, recall_dt,
 f1_score_dt, confusion_matrix_dt) = _compute_metrics(y_test, y_pred_decision_tree)

# Calculate performance metrics for Random Forest
(accuracy_rf, precision_rf, recall_rf,
 f1_score_rf, confusion_matrix_rf) = _compute_metrics(y_test, pred)

# Print the performance metrics for each model.
# The leading "\n" on the second and third titles keeps a blank line between
# the model sections, matching the original output layout.
_print_metrics("Logistic Regression Performance Metrics:",
               accuracy_lr, precision_lr, recall_lr, f1_score_lr, confusion_matrix_lr)
_print_metrics("\nDecision Tree Performance Metrics:",
               accuracy_dt, precision_dt, recall_dt, f1_score_dt, confusion_matrix_dt)
_print_metrics("\nRandom Forest Performance Metrics:",
               accuracy_rf, precision_rf, recall_rf, f1_score_rf, confusion_matrix_rf)
Logistic Regression Performance Metrics: Accuracy: 0.88 Precision: 0.71 Recall: 0.39 F1 Score: 0.50 Confusion Matrix: [[222 7] [ 27 17]] Decision Tree Performance Metrics: Accuracy: 0.77 Precision: 0.29 Recall: 0.27 F1 Score: 0.28 Confusion Matrix: [[199 30] [ 32 12]] Random Forest Performance Metrics: Accuracy: 0.86 Precision: 0.73 Recall: 0.18 F1 Score: 0.29 Confusion Matrix: [[226 3] [ 36 8]]
Done By: Mudit Sharma - 21BCE2223